Используются три задачи: Romo, DM и CtxDM (каждая в двух вариантах: RomoTask1/2, DMTask1/2, CtxDMTask1/2).
Сеть состоит из LIF AdEx-нейронов.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cgtasknet.net.lifadex import SNNlifadex
from cgtasknet.tasks.reduce import (
CtxDMTaskParameters,
DMTaskParameters,
DMTaskRandomModParameters,
MultyReduceTasks,
RomoTaskParameters,
RomoTaskRandomModParameters,
)
from norse.torch.functional.lif_adex import LIFAdExParameters
from tqdm import tqdm
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"{device=}")
device=device(type='cuda', index=0)
import os
def plot_results(
    inputs, target_outputs, outputs, *, n_plots=None, task_names=None, run_name=None
):
    """Plot and save inputs, task code, target and network outputs trial by trial.

    For each of the first ``n_plots`` trials a 1x4 figure is drawn (stimulus
    inputs, task-code channels, target output, network output), saved as
    ``figures/network_outputs_<run_name>_batch_<trial>.pdf`` and shown.

    :param inputs: torch tensor or array indexed ``[time, trial, channel]``;
        channels 0-2 are plotted as stimulus inputs, channels 3.. as task code.
    :param target_outputs: tensor or array indexed ``[time, trial, output]``.
    :param outputs: network output with the same indexing as ``target_outputs``.
    :param n_plots: number of trials to plot; defaults to
        ``min(number of trials, 20)`` and is clamped to the number of trials.
    :param task_names: tick labels (shown sorted) for the task-code panel;
        defaults to the global ``tasks``.
    :param run_name: tag used in the PDF file names; defaults to the global ``name``.
    """

    def _to_numpy(array):
        # Tensors may live on the GPU and be part of an autograd graph.
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    # Convert once up front (previously NumPy inputs left t_outputs unbound and
    # outputs was re-copied from the device on every trial).
    inputs = _to_numpy(inputs)
    t_outputs = _to_numpy(target_outputs)
    outputs = _to_numpy(outputs)

    n_trials = inputs.shape[1]
    n_plots = min(n_trials, 20) if n_plots is None else min(n_plots, n_trials)
    if task_names is None:
        task_names = tasks
    if run_name is None:
        run_name = name

    os.makedirs("figures", exist_ok=True)

    for trial in range(n_plots):
        fig = plt.figure(figsize=(15, 3))

        ax1 = fig.add_subplot(141)
        ax1.set_title("Inputs")
        ax1.set_xlabel("$time, ms$")
        ax1.set_ylabel("$Magnitude$")
        for i in range(3):
            ax1.plot(inputs[:, trial, i], label=rf"$in_{i + 1}$")
        ax1.legend()

        ax2 = fig.add_subplot(142)
        ax2.set_title("Task code (context)")
        ax2.set_xticks(np.arange(1, len(task_names) + 1))
        ax2.set_xticklabels(sorted(task_names), rotation=90)
        ax2.set_yticks([])
        # One vertical bar per task-code channel, height = its value at t = 0.
        for i in range(3, inputs.shape[-1]):
            ax2.plot([i - 2] * 2, [0, inputs[0, trial, i]])

        ax3 = fig.add_subplot(143)
        ax3.set_title("Target output")
        ax3.set_xlabel("$time, ms$")
        for i in range(t_outputs.shape[-1]):
            ax3.plot(t_outputs[:, trial, i], label=rf"$out_{i + 1}$")
        ax3.legend()

        ax4 = fig.add_subplot(144)
        ax4.set_title("Real output")
        ax4.set_xlabel("$time, ms$")
        for i in range(outputs.shape[-1]):
            ax4.plot(outputs[:, trial, i], label=rf"$out_{i + 1}$")
        ax4.legend()

        fig.tight_layout()
        fig.savefig(
            os.path.join("figures", f"network_outputs_{run_name}_batch_{trial}.pdf")
        )
        plt.show()
        plt.close(fig)
# Training-run sizes.
batch_size = 100
number_of_epochs = 2000
number_of_tasks = 1

# Timing parameters of the Romo and DM task families.
romo_parameters = RomoTaskRandomModParameters(
    romo=RomoTaskParameters(
        delay=0.1,
        positive_shift_delay_time=1.4,
        trial_time=0.1,
        positive_shift_trial_time=0.2,
    ),
)
dm_parameters = DMTaskRandomModParameters(
    dm=DMTaskParameters(trial_time=0.1, positive_shift_trial_time=0.8)
)
# The context-DM tasks reuse the DM timing parameters.
ctx_parameters = CtxDMTaskParameters(dm=dm_parameters.dm)

# Upper bound for the input-noise level used in training.
sigma = 0.5

tasks = ["RomoTask1", "RomoTask2", "DMTask1", "DMTask2", "CtxDMTask1", "CtxDMTask2"]
# Both variants (1 and 2) of a task family share the same parameter object.
task_dict = dict(
    zip(
        tasks,
        [romo_parameters] * 2 + [dm_parameters] * 2 + [ctx_parameters] * 2,
    )
)
Task = MultyReduceTasks(
    tasks=task_dict,
    batch_size=batch_size,
    delay_between=0,
    enable_fixation_delay=True,
)
# Show the configuration of every task and the resulting network I/O sizes.
print("Task parameters:")
for task_key, task_params in task_dict.items():
    print(f"{task_key}:\n{task_params}\n")
n_inputs, n_outputs = Task.feature_and_act_size
print(f"inputs/outputs: {n_inputs}/{n_outputs}")
Task parameters: RomoTask1: RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2) RomoTask2: RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2) DMTask1: DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2) DMTask2: DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2) CtxDMTask1: CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None)) CtxDMTask2: CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None)) inputs/outputs: 9/3
inputs, t_outputs = Task.dataset(n_trials=1)
# Preview the generated data for the first (at most 10) trials of the batch:
# (panel title, source array, channel indices, legend prefix).
panel_specs = (
    ("Inputs", inputs, range(3), "in"),
    ("Task code (context)", inputs, range(3, inputs.shape[-1]), "in"),
    ("Target output", t_outputs, range(t_outputs.shape[-1]), "out"),
)
for trial in range(min(batch_size, 10)):
    fig = plt.figure(figsize=(15, 3))
    for position, (title, source, channels, prefix) in enumerate(panel_specs, start=1):
        fig.add_subplot(1, 3, position)
        plt.title(title)
        plt.xlabel("$time, ms$")
        if position == 1:
            plt.ylabel("$Magnitude$")
        for ch in channels:
            plt.plot(source[:, trial, ch], label=rf"${prefix}_{ch + 1}$")
        plt.legend()
        plt.tight_layout()
    plt.show()
    plt.close()
# The preview batch is no longer needed.
del inputs, t_outputs
# Network sizes: the task defines the input and output dimensions.
feature_size, output_size = Task.feature_and_act_size
hidden_size = 450

# Each hidden neuron gets its own inverse adaptation time constant,
# drawn uniformly from [0.5, 6).
tau_ada_inv_low, tau_ada_inv_high = 0.5, 6
neuron_parameters = LIFAdExParameters(
    v_th=torch.as_tensor(0.65),
    tau_ada_inv=tau_ada_inv_low
    + (tau_ada_inv_high - tau_ada_inv_low) * torch.rand(hidden_size).to(device),
    alpha=100,
    method="super",
    # rho_reset=torch.as_tensor(5),  # disabled option
)

model = SNNlifadex(
    feature_size,
    hidden_size,
    output_size,
    neuron_parameters=neuron_parameters,
    tau_filter_inv=500,
)
model = model.to(device)
learning_rate = 1e-2


class RMSELoss(nn.Module):
    """Root-mean-square error: the square root of ``nn.MSELoss``."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, yhat, y):
        mean_squared = self.mse(yhat, y)
        return mean_squared.sqrt()
# Training objective: mean-squared error between network and target outputs.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
Если памяти недостаточно, данные для каждой эпохи нужно генерировать прямо в основном цикле обучения.
# Optional: generate the whole training set up front (disabled by default;
# the training loop below generates a fresh batch every epoch instead).
precompute_dataset = False
if precompute_dataset:
    list_inputs = []
    list_t_outputs = []
    for _ in tqdm(range(number_of_epochs)):
        temp_input, temp_t_output = Task.dataset()
        # astype returns a new array, so the result must be reassigned for the
        # float16 cast (and its memory saving) to take effect.
        temp_input = temp_input.astype(np.float16)
        temp_t_output = temp_t_output.astype(np.float16)
        # Noise only on the three stimulus channels, as in the training loop;
        # the task-code channels stay clean.
        # NOTE(review): the training loop draws a per-trial noise level in
        # [0, sigma); this uses a fixed sigma for every trial.
        temp_input[:, :, :3] += np.random.normal(
            0, sigma, size=temp_input[:, :, :3].shape
        )
        list_inputs.append(temp_input)
        list_t_outputs.append(temp_t_output)
TODO: перенести этот генератор шума в cgtasknet и импортировать его оттуда.
from numba import njit, prange
import time
@njit(cache=True, parallel=True)
def every_bath_generator(
    start_sigma: float,
    stop_sigma: float,
    times: int = 1,
    batches: int = 1,
    actions: int = 1,
):
    """Return zero-mean Gaussian noise of shape (times, batches, actions).

    Every batch column gets its own standard deviation, drawn uniformly from
    [start_sigma, stop_sigma); columns are filled in parallel via prange.
    """
    noise = np.zeros((times, batches, actions))
    for column in prange(batches):
        column_sigma = np.random.uniform(start_sigma, stop_sigma)
        noise[:, column, :] = np.random.normal(0, column_sigma, size=(times, actions))
    return noise
# Warm-up call so numba compiles (and caches) the specialization that is used
# later. numba compiles one version per argument-type signature, so the types
# here mirror the training/test calls: int start, float stop, int sizes.
# Both sigmas are 0, so the result is all zeros.
every_bath_generator(0, 0.0, 1, 1, 3)
array([[[0.]]])
from cgtasknet.instruments.instrument_accuracy_network import correct_answer
from cgtasknet.net.states import LIFAdExRefracInitState
name = f"Train_dm_and_romo_task_reduce_lif_adex_without_refrac_random_delay_long_a_alpha_{neuron_parameters.alpha}_N_{hidden_size}"
# NOTE(review): init_state is created but never passed to the model below.
init_state = LIFAdExRefracInitState(batch_size, hidden_size, device=device)

log_every = 10  # epochs between loss logging, checkpointing and validation
n_validation_batches = 10  # batches of batch_size trials per validation run
validation_noise_sigma = 0.01  # std of the stimulus noise during validation

running_loss = 0.0
for i in tqdm(range(number_of_epochs)):
    inputs, target_outputs = Task.dataset()
    # Noise on the three stimulus channels only; every_bath_generator draws a
    # separate noise level in [0, sigma) for each trial of the batch.
    inputs[:, :, :3] += every_bath_generator(
        0, sigma, inputs.shape[0], inputs.shape[1], 3
    )
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)

    # forward + backward + optimize
    optimizer.zero_grad()
    outputs, _ = model(inputs)
    loss = criterion(outputs, target_outputs)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if i % log_every != log_every - 1:
        continue

    # Mean training loss over the last log_every epochs.
    with open("log_multy.txt", "a") as f:
        f.write("epoch: {:d} loss: {:0.5f}\n".format(i + 1, running_loss / log_every))
    running_loss = 0.0
    torch.save(model.state_dict(), name)

    # Validation: no autograd graph is needed, which keeps GPU memory low.
    result = 0
    with torch.no_grad():
        for j in range(n_validation_batches):
            # Free the previous batch (training or validation) before the next one.
            del inputs, target_outputs, outputs
            torch.cuda.empty_cache()
            inputs, target_outputs = Task.dataset(1, delay_between=0)
            # Same channels as in training: the task-code channels stay noise-free.
            inputs[:, :, :3] += np.random.normal(
                0, validation_noise_sigma, size=inputs[:, :, :3].shape
            )
            inputs = torch.from_numpy(inputs).type(torch.float).to(device)
            target_outputs = (
                torch.from_numpy(target_outputs).type(torch.float).to(device)
            )
            outputs = model(inputs)[0]
            answers = correct_answer(
                outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
            )
            result += torch.sum(answers).item()
    # Percentage of correct trials over all validation batches.
    accuracy = result / batch_size / n_validation_batches * 100
    with open("accuracy_multy.txt", "a") as f:
        f.write(f"ecpoch = {i}; correct/all = {accuracy}\n")
    del inputs, target_outputs, outputs
    torch.cuda.empty_cache()
print("Finished Training")
100%|██████████| 2000/2000 [3:28:05<00:00, 6.24s/it]
Finished Training
def test_network(test_sigma: float, number_of_trials: int = 100, plot_data: bool = True):
    """Return the model's accuracy (in percent) on freshly generated noisy batches.

    Each of ``number_of_trials`` batches comes from
    ``Task.dataset(1, delay_between=0)``. The three stimulus channels get
    Gaussian noise whose standard deviation ``every_bath_generator`` draws per
    trial from ``[0, test_sigma)``. The model runs without autograd.

    :param test_sigma: upper bound of the per-trial noise standard deviation.
    :param number_of_trials: number of batches (of ``batch_size`` trials each).
    :param plot_data: if true, plot the last evaluated batch with ``plot_results``.
    :return: percentage of correctly answered trials.
    :raises ValueError: if ``number_of_trials`` is not positive.
    """
    if number_of_trials <= 0:
        raise ValueError(f"number_of_trials must be positive, got {number_of_trials}")

    result = 0
    with torch.no_grad():
        for _ in tqdm(range(number_of_trials)):
            # Drop the previous batch before allocating the next one.
            inputs = target_outputs = outputs = None
            torch.cuda.empty_cache()
            inputs, target_outputs = Task.dataset(1, delay_between=0)
            inputs[:, :, :3] += every_bath_generator(
                0, test_sigma, inputs.shape[0], inputs.shape[1], 3
            )
            inputs = torch.from_numpy(inputs).type(torch.float).to(device)
            target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
            outputs = model(inputs)[0]
            answers = correct_answer(
                outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
            )
            result += torch.sum(answers).item()

    accuracy = result / batch_size / number_of_trials * 100
    if plot_data:
        plot_results(inputs, target_outputs, outputs)
    del inputs, target_outputs, outputs
    torch.cuda.empty_cache()
    return accuracy
# Accuracy with weak input noise (per-trial std drawn from [0, 0.01)).
# (A stray np.random.normal(..., size=inputs.shape) was removed: its result was
# discarded, and `inputs` no longer exists after training.)
accuracy = test_network(0.01, 100)
print(accuracy)
100%|██████████| 100/100 [03:26<00:00, 2.06s/it]
96.03
# Accuracy with per-trial input-noise std drawn from [0, 0.05).
# (A stray np.random.normal(..., size=inputs.shape) was removed: its result was
# discarded, and `inputs` no longer exists after training.)
accuracy = test_network(0.05)
print(accuracy)
100%|██████████| 100/100 [03:20<00:00, 2.00s/it]
95.37
# Accuracy with per-trial input-noise std drawn from [0, 0.1).
# (A stray np.random.normal(..., size=inputs.shape) was removed: its result was
# discarded, and `inputs` no longer exists after training.)
accuracy = test_network(0.1)
print(accuracy)
100%|██████████| 100/100 [03:21<00:00, 2.01s/it]
95.38
# Accuracy with per-trial input-noise std drawn from [0, 0.5), the training range.
# (A stray np.random.normal(..., size=inputs.shape) was removed: its result was
# discarded, and `inputs` no longer exists after training.)
accuracy = test_network(0.5)
print(accuracy)
100%|██████████| 100/100 [03:19<00:00, 1.99s/it]
94.15
# Single-batch check with a fixed noise std of 0.5 on the stimulus channels,
# followed by plots of that batch.
n_trials = 1
result = 0
with torch.no_grad():
    for j in tqdm(range(n_trials)):
        # Release tensors left over from earlier cells before allocating new ones.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, 0.5, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Percentage correct: divide by the number of batches actually run
# (a hard-coded 100 here under-reported the accuracy by a factor of 100).
accuracy = result / batch_size / n_trials * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 1/1 [00:02<00:00, 2.26s/it]
0.89
# Single-batch check with a fixed noise std of 0.7 on the stimulus channels,
# followed by plots of that batch.
n_trials = 1
result = 0
with torch.no_grad():
    for j in tqdm(range(n_trials)):
        # Release tensors left over from earlier cells before allocating new ones.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, 0.7, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Percentage correct: divide by the number of batches actually run
# (a hard-coded 100 here under-reported the accuracy by a factor of 100).
accuracy = result / batch_size / n_trials * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 1/1 [00:02<00:00, 2.35s/it]
0.83
# Rebind the last evaluation tensors (presumably to drop references to them).
inputs = outputs = 0
# Store the per-neuron inverse adaptation time constants drawn for this model
# (np.save appends the ".npy" extension).
tau_ada_inv_distrib = neuron_parameters.tau_ada_inv.cpu().numpy()
np.save(f"tau_ada_inv_alpha={neuron_parameters.alpha}", tau_ada_inv_distrib)
# Validation accuracy logged during training, one line per check in the form
# "ecpoch = <epoch index>; correct/all = <accuracy %>".
epochs = []
lines = []
with open("accuracy_multy.txt", "r") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        fields = line.split("=")
        epochs.append(int(fields[1].split(";")[0]))
        lines.append(float(fields[2]))
plt.figure(figsize=(8, 5))
# The x values come from the log itself, so the plot does not depend on a
# hard-coded epoch count (the log is appended to across runs).
plt.plot(epochs, lines, ".", linestyle="--", markersize=5)
plt.ylabel(r"Accuracy%")
plt.xlabel(r"Epochs")
Text(0.5, 0, 'Epochs')
# Accuracy as a function of a fixed input-noise std on the stimulus channels;
# one "noise=<sigma>:accuracy=<acc>" line per level is appended to
# accuracy_vs_noise.txt.
start_sigma = 0
stop_sigma = 2
step_sigma = 0.05
sigma_array = np.arange(start_sigma, stop_sigma, step_sigma)
n_batches_per_sigma = 20

with torch.no_grad():
    for test_sigma in tqdm(sigma_array):
        result = 0
        for j in range(n_batches_per_sigma):
            # Release the previous batch before allocating the next one.
            inputs = target_outputs = outputs = None
            torch.cuda.empty_cache()
            inputs, target_outputs = Task.dataset(1, delay_between=0)
            inputs[:, :, :3] += np.random.normal(
                0, test_sigma, size=inputs[:, :, :3].shape
            )
            inputs = torch.from_numpy(inputs).type(torch.float).to(device)
            target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
            outputs = model(inputs)[0]
            answers = correct_answer(
                outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
            )
            result += torch.sum(answers).item()
        accuracy = result / batch_size / n_batches_per_sigma * 100
        with open("accuracy_vs_noise.txt", "a") as f:
            f.write(f"noise={test_sigma}:accuracy={accuracy}\n")

inputs = target_outputs = outputs = None
torch.cuda.empty_cache()
100%|██████████| 40/40 [28:18<00:00, 42.46s/it]
import matplotlib.patches as patches
# Use the ggplot look for the remaining figures.
plt.style.use(style="ggplot")
def parser(line_text: str) -> tuple:
    """Parse a ``name_1=v1:name_2=v2`` line and return ``(v1, v2)`` as floats.

    Used for the ``noise=<sigma>:accuracy=<acc>`` lines of
    ``accuracy_vs_noise.txt``. Surrounding whitespace such as a trailing
    newline is accepted, because ``float`` ignores it.

    :param line_text: line to parse; only its first two ``:``-separated
        fields are used.
    :return: ``(v1, v2)``.
    """
    fields = line_text.split(":")
    first_value = fields[0].split("=")[1]
    second_value = fields[1].split("=")[1]
    return float(first_value), float(second_value)
# Read the file appended to by the noise sweep ("noise=<sigma>:accuracy=<acc>"
# per line). The path is relative to the working directory; point it elsewhere
# (e.g. the former absolute A:\ location) if the file lives somewhere else.
accuracy_vs_noise_path = "accuracy_vs_noise.txt"
x, y = [], []
with open(accuracy_vs_noise_path, "r") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        t_x, t_y = parser(line)
        x.append(t_x)
        y.append(t_y)

fig, ax = plt.subplots()
ax.plot(x, y, ".", linestyle="--")
ax.set_ylabel("Accuracy,%")
ax.set_xlabel(r"$\sigma$")
# Shade sigma in [0, sigma] (0.5 in this configuration): during training the
# per-trial noise std is drawn from [0, sigma).
ax.add_patch(
    patches.Rectangle(
        (0, 50), sigma, 50, edgecolor="grey", facecolor="grey", alpha=0.5, fill=True
    )
)
plt.show()
plt.close()